package com.socket.secure.filter;

import cn.hutool.core.io.IoUtil;
import cn.hutool.core.util.ArrayUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.extra.servlet.ServletUtil;
import cn.hutool.json.JSONConfig;
import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil;
import com.socket.secure.constant.RequsetTemplate;
import com.socket.secure.exception.InvalidRequestException;
import com.socket.secure.exception.RequestNotSupportedException;
import com.socket.secure.util.AESUtil;
import com.socket.secure.util.Assert;
import com.socket.secure.util.HmacEnum;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.servlet.ReadListener;
import javax.servlet.ServletException;
import javax.servlet.ServletInputStream;
import javax.servlet.ServletRequest;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.Part;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.util.Collections;
import java.util.Enumeration;
import java.util.LinkedHashMap;
import java.util.Map;

/**
 * Secure servlet request wrapper
 */
final class SecureRequestWrapper extends HttpServletRequestWrapper {
    private static final Logger log = LoggerFactory.getLogger(SecureRequestWrapper.class);
    /**
     * Json config
     */
    private final JSONConfig config = JSONConfig.create().setIgnoreNullValue(false);
    /**
     * Request parameters
     */
    private final Map<String, String[]> params = new LinkedHashMap<>();
    /**
     * Request formdata files
     */
    private final Map<String, byte[]> files = new LinkedHashMap<>();
    /**
     * Whether to verify the file signature
     */
    private final boolean checkFile;
    /**
     * Request body
     */
    private JSONObject body = new JSONObject();
    /**
     * Base36 Link timestamp
     */
    private String timestamp;
    /**
     * Link signature
     */
    private String signature;

    /**
     * Secure servlet request wrapper
     *
     * @param request   {@link ServletRequest}
     * @param checkFile whether to verify the file
     */
    public SecureRequestWrapper(HttpServletRequest request, boolean checkFile) throws IOException, ServletException {
        super(request);
        this.checkFile = checkFile;
        this.initialize();
    }

    private void initialize() throws IOException, ServletException {
        ServletInputStream stream = super.getInputStream();
        if (checkFile && ServletUtil.isMultipart(this)) {
            for (Part part : super.getParts()) {
                String name = part.getSubmittedFileName();
                if (StrUtil.isNotEmpty(name)) {
                    files.put(name, IoUtil.readBytes(part.getInputStream()));
                }
            }
        }
        if (ServletUtil.isPostMethod(this)) {
            String body = IoUtil.read(stream, StandardCharsets.UTF_8);
            boolean empty = StrUtil.isEmpty(body);
            boolean isJSON = empty || JSONUtil.isTypeJSONObject(body);
            // Must be JSON type
            Assert.isTrue(isJSON, () -> new RequestNotSupportedException(RequsetTemplate.NOT_SUPPORTED_PAYLOAD, body));
            this.body = empty ? new JSONObject() : JSONUtil.parseObj(body, config);
        }
        this.params.putAll(super.getParameterMap());
        // pring-log
        log.debug("params: {}", params);
        log.debug("body: {}", body);
    }

    /**
     * try to decrypt all data
     *
     * @param signName request signature name
     * @throws InvalidRequestException decryption error
     */
    void decryptRequset(String signName) {
        // Decrypt FormData
        params.forEach(this::decrypt);
        // Decryption request body
        body.forEach(this::decrypt);
        // Find signature
        String signature = findSignature(signName);
        this.signature = signature.substring(0, 40);
        this.timestamp = signature.substring(40);
    }

    /**
     * Decrypt FormData
     */
    private void decrypt(String key, String[] values) {
        if (ArrayUtil.length(values) == 1) {
            values[0] = AESUtil.decrypt(values[0], getSession());
            params.put(key, values);
        }
    }

    /**
     * Decryption request body
     */
    private void decrypt(String key, Object value) {
        String val = AESUtil.decrypt(value.toString(), getSession());
        body.set(key, val);
    }

    /**
     * Find request signature
     */
    private String findSignature(String sign) {
        // find signatures in forms
        String[] values = params.remove(sign);
        if (values != null && values.length == 1) {
            return values[0];
        }
        // find the signature in the payload
        if (body.containsKey(sign)) {
            return (String) body.remove(sign);
        }
        // Can't find throws exception
        throw new InvalidRequestException(RequsetTemplate.REQUSET_SIGNATURE_NOT_FOUNT);
    }

    @Override
    public ServletInputStream getInputStream() {
        return new ServletInputStream() {
            private final ByteBuffer buffer = ByteBuffer.wrap(body.toString().getBytes());

            @Override
            public int read() {
                return isFinished() ? -1 : buffer.get();
            }

            @Override
            public boolean isFinished() {
                return buffer.remaining() == 0;
            }

            @Override
            public boolean isReady() {
                return false;
            }

            @Override
            public void setReadListener(ReadListener listener) {
                // implements too hard, ignore
                throw new UnsupportedOperationException();
            }
        };
    }

    @Override
    public String getParameter(String name) {
        return ArrayUtil.firstMatch(e -> true, params.get(name));
    }

    @Override
    public Map<String, String[]> getParameterMap() {
        return Collections.unmodifiableMap(params);
    }

    @Override
    public Enumeration<String> getParameterNames() {
        return Collections.enumeration(params.keySet());
    }

    @Override
    public String[] getParameterValues(String name) {
        return params.get(name);
    }

    @Override
    public BufferedReader getReader() {
        return new BufferedReader(new InputStreamReader(getInputStream()));
    }

    /**
     * Verify that the signature matches the data
     *
     * @return match returns true
     */
    boolean compareSign() {
        StringBuilder sb = new StringBuilder();
        // Generate params signature
        for (String[] values : params.values()) {
            if (ArrayUtil.isEmpty(values)) {
                continue;
            }
            if (values.length > 1) {
                sb.append(String.join("", values));
                continue;
            }
            sb.append(values[0]);
        }
        // Generate body signature
        for (Object value : body.values()) {
            String vstr = value.toString();
            if (JSONUtil.isTypeJSONArray(vstr)) {
                sb.append(JSONUtil.parseArray(vstr).join(""));
                continue;
            }
            sb.append(vstr);
        }
        if (checkFile) {
            for (Map.Entry<String, byte[]> entry : files.entrySet()) {
                String digest = HmacEnum.MD5.digestHex(this, entry.getValue());
                sb.append(HmacEnum.MD5.digestHex(this, entry.getKey() + digest));
            }
        }
        log.debug("secure join: " + sb);
        // check sign
        return HmacEnum.SHA1.digestHex(this, sb + timestamp).equals(signature);
    }

    /**
     * request timestamp
     */
    long getTimestamp() {
        return Long.parseLong(timestamp, 36);
    }

    /**
     * request signature
     */
    String sign() {
        return signature;
    }
}
